"""
This script contains classes which define the dataloaders for error-guided finetuning algorithm.

The detailed description of this algorithm can be found in our paper.
"""
import torch
import random
import numpy as np

from tools.utils import *
from scipy.io import loadmat
from tools.dataloader_for_LFZSSR import *

class DataloaderForLFZSSRWithBatch_ErrorGuide:
    """
    Error-guided variant of DataloaderForLFZSSRWithBatch.

    Differences from the base dataloader:
    1. Patch locations are sampled from a probability map (a normalized,
       all-positive L1 error map) instead of uniformly at random.
    2. The scale-augmentation step of the base loader is removed.

    The probability map is produced by QuickTest in the main class: the LLR
    light field is super-resolved, the error against the LLR target view is
    computed, and that error map is normalized into the probability map that
    is handed to get_patch().
    """
    def __init__(self,
                 mat_path,
                 refPos,
                 scale,
                 view_num,
                 patch_size=64,
                 batch_size=8,
                 random_flip_vertical=True,
                 random_flip_horizontal=True,
                 random_rotation=True):
        # Load the HR light field, crop it so its spatial size is a multiple
        # of `scale`, and derive the LR light field from it.
        self.mat_data = loadmat(mat_path)
        self.lf_hr = lf_modcrop(self.mat_data['lf_hr'], scale)
        self.lf_lr = LF_downscale(self.lf_hr, scale)

        self.scale = scale
        self.refPos = refPos
        self.view_num = view_num

        self.patch_size = patch_size
        self.batch_size = batch_size

        # Per-sample augmentation switches.
        self.random_flip_vertical = random_flip_vertical
        self.random_flip_horizontal = random_flip_horizontal
        self.random_rotation = random_rotation

        self.ori_view_num = self.lf_hr.shape[0]
        self.lr_height = self.lf_lr.shape[2]
        self.lr_width = self.lf_lr.shape[3]
        self.resize = bicubic_imresize()

    def get_patch(self, prob_map):
        """
        Sample one batch of (father, son) patch pairs guided by `prob_map`.

        `prob_map` is the flattened, normalized LR error map; higher values
        make a pixel more likely to be chosen as a patch center.
        Returns (lr_patch, hr_patch, lr_ref_patch, hr_ref_patch).
        """
        # Keep the central view_num x view_num angular window.
        offset = (self.ori_view_num - self.view_num) // 2
        angular = slice(offset, offset + self.view_num)
        cropped = self.lf_lr[angular, angular]

        hr_father = torch.Tensor(cropped.astype(np.float32) / 255.0)
        hr_father = hr_father.contiguous().view(1, -1, self.lr_height, self.lr_width)

        # Mod-crop the father so downscaling by `scale` is exact.
        father_hei = hr_father.shape[2] - hr_father.shape[2] % self.scale
        father_wid = hr_father.shape[3] - hr_father.shape[3] % self.scale
        hr_father = hr_father[:, :, :father_hei, :father_wid]

        lr_son = self.resize(hr_father, 1.0 / self.scale)
        son_hei, son_wid = lr_son.shape[2], lr_son.shape[3]

        # Back to numpy, one angular dimension restored.
        hr_father = hr_father.view(self.view_num, -1, father_hei, father_wid).numpy()
        lr_son = lr_son.view(self.view_num, -1, son_hei, son_wid).numpy()

        # Draw all per-sample augmentation decisions up front.
        self.aug_vertical = np.random.rand(self.batch_size)
        self.aug_horizontal = np.random.rand(self.batch_size)
        self.aug_rotation = np.random.randint(low=1, high=5, size=(self.batch_size))

        self.prob_map_total = prob_map
        # Sample flat pixel indices, then turn them into top-left corners.
        self.crop_indices = self.make_list_of_crop_indices()
        self.position_stack = self.get_top_left(self.batch_size, self.patch_size)

        # Crop + augment the selected patch pairs.
        hr_patch, lr_patch = self.crop_slice(hr_father, lr_son)

        u, v = self.refPos
        hr_ref_patch = torch.Tensor(hr_patch[:, u, v, :, :].copy()).contiguous().unsqueeze(1)
        lr_ref_patch = torch.Tensor(lr_patch[:, u, v, :, :].copy()).contiguous().unsqueeze(1)

        hr_patch = torch.Tensor(hr_patch.copy()).contiguous()
        hr_patch = hr_patch.view(self.batch_size, -1, hr_patch.shape[3], hr_patch.shape[4])
        lr_patch = torch.Tensor(lr_patch.copy()).contiguous()
        lr_patch = lr_patch.view(self.batch_size, -1, lr_patch.shape[3], lr_patch.shape[4])

        return lr_patch, hr_patch, lr_ref_patch, hr_ref_patch

    def make_list_of_crop_indices(self):
        """Draw batch_size flat pixel indices, weighted by the probability map."""
        return np.random.choice(a=len(self.prob_map_total), size=self.batch_size, p=self.prob_map_total)

    def get_top_left(self, batch_size, patch_size):
        """
        Convert the sampled flat indices into top-left crop corners.

        Corners are clamped inside the LR frame and snapped down to a
        multiple of `scale` so the LR-son crop (corner // scale) stays
        aligned with the father crop.
        Returns an int array of shape [batch_size, 2] holding (top, left).
        """
        width = self.lf_lr.shape[3]
        top_limit = self.lf_lr.shape[2] - patch_size
        left_limit = self.lf_lr.shape[3] - patch_size
        half = patch_size // 2

        tops, lefts = [], []
        for center in self.crop_indices[:batch_size]:
            row, col = divmod(int(center), width)
            top = min(max(0, row - half), top_limit)
            left = min(max(0, col - half), left_limit)
            # Snap to the scale grid for exact father/son correspondence.
            tops.append(top - top % self.scale)
            lefts.append(left - left % self.scale)
        return np.stack([np.array(tops), np.array(lefts)], axis=-1)

    def crop_slice(self, hr_father, lr_son):
        """
        Cut out, and jointly augment, the selected father/son patch pairs.

        Returns two arrays of shape [batch_size, U, V, h, w].
        """
        s, p = self.scale, self.patch_size
        hr_patches, lr_patches = [], []
        for i, (top, left) in enumerate(self.position_stack):
            hr_crop = hr_father[:, :, top: top + p, left: left + p]
            lr_crop = lr_son[:, :, top // s: (top + p) // s, left // s: (left + p) // s]
            hr_crop, lr_crop = self.lf_augmentation(hr_crop, lr_crop, i)
            hr_patches.append(hr_crop)
            lr_patches.append(lr_crop)
        return np.stack(hr_patches, axis=0), np.stack(lr_patches, axis=0)

    def lf_augmentation(self, hr_patch, lr_patch, i):
        """
        Apply the i-th sample's flip/rotation augmentation to both patches.

        hr_patch: [U, V, H, W]; lr_patch: [U, V, h, w]; i indexes the batch.
        Spatial flips/rotations are paired with the matching angular ones so
        the light-field geometry stays consistent.
        """
        if self.random_flip_vertical and self.aug_vertical[i] > 0.5:
            # flip angular axis 0 together with spatial axis 2
            hr_patch = hr_patch[::-1, :, ::-1, :]
            lr_patch = lr_patch[::-1, :, ::-1, :]
        if self.random_flip_horizontal and self.aug_horizontal[i] > 0.5:
            # flip angular axis 1 together with spatial axis 3
            hr_patch = hr_patch[:, ::-1, :, ::-1]
            lr_patch = lr_patch[:, ::-1, :, ::-1]
        if self.random_rotation:
            k = int(self.aug_rotation[i])
            hr_patch = np.rot90(np.rot90(hr_patch, k, (2, 3)), k, (0, 1))
            lr_patch = np.rot90(np.rot90(lr_patch, k, (2, 3)), k, (0, 1))
        return hr_patch, lr_patch


class DataloaderForAlignNetWithBatch_ErrorGuide:
    """
    Error-guided variant of DataloaderForAlignNetWithBatch.

    Produces batches of LR light-field patches for finetuning, with the
    patch centers drawn from an LR probability map.

    The probability map is produced by TestAlignNet in the main class: the
    LR light field is fed through the network, the error of the aligned
    pre-upsampled light field is computed, and that error map is normalized
    into the probability map handed to get_patch().
    """
    def __init__(self,
                 mat_path,
                 refPos,
                 scale,
                 view_num,
                 patch_size=64,
                 batch_size=8,
                 random_flip_vertical=True,
                 random_flip_horizontal=True,
                 random_rotation=True):
        # Load the HR light field, mod-crop it, and derive its LR version.
        self.mat_data = loadmat(mat_path)
        self.lf_hr = lf_modcrop(self.mat_data['lf_hr'], scale)
        self.lf_lr = LF_downscale(self.lf_hr, scale)

        self.refPos = refPos
        self.view_num = view_num

        self.patch_size = patch_size
        self.batch_size = batch_size

        # Per-sample augmentation switches.
        self.random_flip_vertical = random_flip_vertical
        self.random_flip_horizontal = random_flip_horizontal
        self.random_rotation = random_rotation

        self.ori_view_num = self.lf_hr.shape[0]
        self.lr_height = self.lf_lr.shape[2]
        self.lr_width = self.lf_lr.shape[3]

    def get_patch(self, prob_map):
        """
        Sample one batch of LR light-field patches guided by `prob_map`.

        Returns a tensor of shape [batch_size, U*V, patch_size, patch_size].
        """
        # Keep the central view_num x view_num angular window.
        offset = (self.ori_view_num - self.view_num) // 2
        angular = slice(offset, offset + self.view_num)
        lf = self.lf_lr[angular, angular].astype(np.float32) / 255.0

        # Draw all per-sample augmentation decisions up front.
        self.aug_vertical = np.random.rand(self.batch_size)
        self.aug_horizontal = np.random.rand(self.batch_size)
        self.aug_rotation = np.random.randint(low=1, high=5, size=(self.batch_size))

        self.prob_map_total = prob_map
        self.crop_indices = self.make_list_of_crop_indices()
        self.position_stack = self.get_top_left(self.batch_size, self.patch_size)

        # Crop + augment the selected patches.
        patches = self.crop_slice(lf)

        patches = torch.Tensor(patches.copy()).contiguous()
        return patches.view(self.batch_size, -1, self.patch_size, self.patch_size)

    def make_list_of_crop_indices(self):
        """Draw batch_size flat pixel indices, weighted by the probability map."""
        return np.random.choice(a=len(self.prob_map_total), size=self.batch_size, p=self.prob_map_total)

    def get_top_left(self, batch_size, patch_size):
        """
        Convert the sampled flat indices into top-left crop corners, clamped
        inside the LR frame. Returns an int array of shape [batch_size, 2].
        """
        width = self.lf_lr.shape[3]
        top_limit = self.lf_lr.shape[2] - patch_size
        left_limit = self.lf_lr.shape[3] - patch_size
        half = patch_size // 2

        tops, lefts = [], []
        for center in self.crop_indices[:batch_size]:
            row, col = divmod(int(center), width)
            tops.append(min(max(0, row - half), top_limit))
            lefts.append(min(max(0, col - half), left_limit))
        return np.stack([np.array(tops), np.array(lefts)], axis=-1)

    def crop_slice(self, lf_img):
        """
        Cut out and augment the selected patches from `lf_img`.

        Returns an array of shape [batch_size, U, V, h, w].
        """
        patches = []
        for i, (top, left) in enumerate(self.position_stack):
            window = lf_img[:, :, top: top + self.patch_size, left: left + self.patch_size]
            patches.append(self.lf_augmentation(window, i))
        return np.stack(patches)

    def lf_augmentation(self, lf_patch, i):
        """
        Apply the i-th sample's flip/rotation augmentation to `lf_patch`.

        lf_patch: [U, V, H, W]; i indexes the batch. Spatial flips/rotations
        are paired with the matching angular ones to keep the light-field
        geometry consistent.
        """
        if self.random_flip_vertical and self.aug_vertical[i] > 0.5:
            # flip angular axis 0 together with spatial axis 2
            lf_patch = lf_patch[::-1, :, ::-1, :]
        if self.random_flip_horizontal and self.aug_horizontal[i] > 0.5:
            # flip angular axis 1 together with spatial axis 3
            lf_patch = lf_patch[:, ::-1, :, ::-1]
        if self.random_rotation:
            k = int(self.aug_rotation[i])
            lf_patch = np.rot90(np.rot90(lf_patch, k, (2, 3)), k, (0, 1))
        return lf_patch